{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/albert/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "成功加载语料data/example.train, 语料数量18404\n",
      "成功加载语料data/example.test, 语料数量4558\n",
      "成功加载语料data/example.dev, 语料数量4384\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.contrib import seq2seq\n",
    "from elmo import ELMo\n",
    "from data import NERData\n",
    "import os\n",
    "\n",
    "total_epoch = 5000\n",
    "hidden_size = 200\n",
    "vocab_size = 5000\n",
    "max_length = 128\n",
    "entity_class = 7\n",
    "\n",
    "batch_size = 1    # 跟训练的区别\n",
    "\n",
    "ner = NERData(batch_size, max_length)\n",
    "elmo = ELMo(batch_size, hidden_size, vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def network(X):\n",
    "    w = tf.get_variable(\"fcn_w\", [1, hidden_size, entity_class + 1])\n",
    "    b = tf.get_variable(\"fcn_b\", [entity_class + 1])\n",
    "    # 这里输出维度用entity_class + 1而不是entity_class，因为输出里除了7类实体，还有一类用来表示每个句子补齐的<PAD>位\n",
    "    w_tile = tf.tile(w, [batch_size, 1, 1])\n",
    "\n",
    "    logists = tf.nn.softmax(tf.nn.xw_plus_b(X, w_tile, b), name=\"logists\")\n",
    "    return logists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./model/model-150000\n"
     ]
    }
   ],
   "source": [
    "X = tf.placeholder(shape=[batch_size, max_length], dtype=tf.int32, name=\"X\")\n",
    "length = tf.placeholder(shape=[batch_size], dtype=tf.int32, name=\"length\")\n",
    "dropout = tf.placeholder(shape=[], dtype=tf.float32, name=\"dropout\")\n",
    "\n",
    "elmo_vector = elmo.elmo(X, length, dropout)\n",
    "logists = network(elmo_vector)\n",
    "\n",
    "gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)\n",
    "config = tf.ConfigProto(gpu_options=gpu_options)\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "model_dir = \"./model\"\n",
    "saver = tf.train.Saver()\n",
    "check_point = tf.train.get_checkpoint_state(model_dir)\n",
    "saver.restore(sess, check_point.model_checkpoint_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 查看elmo各层的权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embedding层: 0.00021379402\n",
      "forward_1层: 0.028424246\n",
      "forward_2层: 0.50437975\n",
      "backward_1层: 0.02824858\n",
      "backward_2层: 0.43873364\n"
     ]
    }
   ],
   "source": [
    "_X, _length, _targets = ner.get_dev_data(1)\n",
    "fd = {X: _X, length: _length, dropout: 1.}\n",
    "weights = sess.run(tf.nn.softmax(elmo.weights), feed_dict=fd)\n",
    "print(\"embedding层:\", weights[0])\n",
    "print(\"forward_1层:\", weights[1])\n",
    "print(\"forward_2层:\", weights[2])\n",
    "print(\"backward_1层:\", weights[3])\n",
    "print(\"backward_2层:\", weights[4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 用来自不同领域的句子进行测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('李', 'B-PER'),\n",
       " ('克', 'I-PER'),\n",
       " ('强', 'I-PER'),\n",
       " ('来', 'O'),\n",
       " ('到', 'O'),\n",
       " ('位', 'O'),\n",
       " ('于', 'O'),\n",
       " ('江', 'B-LOC'),\n",
       " ('西', 'I-LOC'),\n",
       " ('省', 'I-LOC'),\n",
       " ('赣', 'B-LOC'),\n",
       " ('州', 'I-LOC'),\n",
       " ('市', 'I-LOC'),\n",
       " ('于', 'B-LOC'),\n",
       " ('都', 'I-LOC'),\n",
       " ('县', 'I-LOC'),\n",
       " ('的', 'O'),\n",
       " ('梓', 'B-LOC'),\n",
       " ('山', 'I-LOC'),\n",
       " ('镇', 'I-LOC'),\n",
       " ('潭', 'I-LOC'),\n",
       " ('头', 'I-LOC'),\n",
       " ('村', 'I-LOC'),\n",
       " ('看', 'O'),\n",
       " ('望', 'O'),\n",
       " ('慰', 'O'),\n",
       " ('问', 'O'),\n",
       " ('群', 'O'),\n",
       " ('众', 'O')]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = \"李克强来到位于江西省赣州市于都县的梓山镇潭头村看望慰问群众\"\n",
    "_X, _length = ner.sentence_encode(s)\n",
    "fd = {X: _X, length: _length, dropout: 1.}\n",
    "result = sess.run(logists, feed_dict=fd)\n",
    "result_entities = ner.entities_decode(result[0].argmax(axis=1))\n",
    "list(zip(s, result_entities))[:_length[0]][:len(s)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('多', 'O'),\n",
       " ('名', 'O'),\n",
       " ('白', 'B-ORG'),\n",
       " ('宫', 'I-ORG'),\n",
       " ('官', 'O'),\n",
       " ('员', 'O'),\n",
       " ('对', 'O'),\n",
       " ('媒', 'O'),\n",
       " ('体', 'O'),\n",
       " ('表', 'O'),\n",
       " ('示', 'O'),\n",
       " ('，', 'O'),\n",
       " ('美', 'B-LOC'),\n",
       " ('国', 'I-LOC'),\n",
       " ('总', 'O'),\n",
       " ('统', 'O'),\n",
       " ('特', 'B-PER'),\n",
       " ('朗', 'I-LOC'),\n",
       " ('普', 'I-PER'),\n",
       " ('不', 'O'),\n",
       " ('准', 'O'),\n",
       " ('备', 'O'),\n",
       " ('续', 'O'),\n",
       " ('签', 'O'),\n",
       " ('即', 'O'),\n",
       " ('将', 'O'),\n",
       " ('到', 'O'),\n",
       " ('期', 'O'),\n",
       " ('的', 'O'),\n",
       " ('美', 'B-LOC'),\n",
       " ('俄', 'B-LOC'),\n",
       " ('《', 'O'),\n",
       " ('新', 'O'),\n",
       " ('削', 'O'),\n",
       " ('减', 'O'),\n",
       " ('战', 'O'),\n",
       " ('略', 'O'),\n",
       " ('武', 'O'),\n",
       " ('器', 'O'),\n",
       " ('条', 'O'),\n",
       " ('约', 'O'),\n",
       " ('》', 'O'),\n",
       " ('，', 'O'),\n",
       " ('而', 'O'),\n",
       " ('是', 'O'),\n",
       " ('想', 'O'),\n",
       " ('推', 'O'),\n",
       " ('动', 'O'),\n",
       " ('达', 'O'),\n",
       " ('成', 'O'),\n",
       " ('一', 'O'),\n",
       " ('项', 'O'),\n",
       " ('包', 'O'),\n",
       " ('括', 'O'),\n",
       " ('中', 'B-LOC'),\n",
       " ('国', 'I-LOC'),\n",
       " ('在', 'O'),\n",
       " ('内', 'O'),\n",
       " ('的', 'O'),\n",
       " ('新', 'O'),\n",
       " ('条', 'O'),\n",
       " ('约', 'O')]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = \"多名白宫官员对媒体表示，美国总统特朗普不准备续签即将到期的美俄《新削减战略武器条约》，而是想推动达成一项包括中国在内的新条约\"\n",
    "_X, _length = ner.sentence_encode(s)\n",
    "fd = {X: _X, length: _length, dropout: 1.}\n",
    "result = sess.run(logists, feed_dict=fd)\n",
    "result_entities = ner.entities_decode(result[0].argmax(axis=1))\n",
    "list(zip(s, result_entities))[:_length[0]][:len(s)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('小', 'B-PER'),\n",
       " ('罗', 'B-PER'),\n",
       " ('伯', 'I-PER'),\n",
       " ('特', 'I-PER'),\n",
       " ('·', 'I-PER'),\n",
       " ('唐', 'I-PER'),\n",
       " ('尼', 'I-PER'),\n",
       " ('专', 'I-PER'),\n",
       " ('门', 'O'),\n",
       " ('为', 'O'),\n",
       " ('漫', 'O'),\n",
       " ('威', 'O'),\n",
       " ('总', 'O'),\n",
       " ('裁', 'O'),\n",
       " ('凯', 'B-PER'),\n",
       " ('文', 'I-PER'),\n",
       " ('费', 'I-PER'),\n",
       " ('奇', 'I-PER'),\n",
       " ('颁', 'O'),\n",
       " ('发', 'O'),\n",
       " ('特', 'O'),\n",
       " ('别', 'O'),\n",
       " ('大', 'O'),\n",
       " ('奖', 'O'),\n",
       " ('，', 'O'),\n",
       " ('他', 'O'),\n",
       " ('在', 'O'),\n",
       " ('介', 'O'),\n",
       " ('绍', 'O'),\n",
       " ('凯', 'B-PER'),\n",
       " ('文', 'I-PER'),\n",
       " ('费', 'I-PER'),\n",
       " ('奇', 'I-PER'),\n",
       " ('的', 'O'),\n",
       " ('时', 'O'),\n",
       " ('候', 'O'),\n",
       " ('说', 'O'),\n",
       " ('道', 'O'),\n",
       " ('“', 'O'),\n",
       " ('我', 'O'),\n",
       " ('要', 'O'),\n",
       " ('感', 'O'),\n",
       " ('谢', 'O'),\n",
       " ('凯', 'B-PER'),\n",
       " ('文', 'I-PER'),\n",
       " ('，', 'O'),\n",
       " ('他', 'O'),\n",
       " ('在', 'O'),\n",
       " ('我', 'O'),\n",
       " ('的', 'O'),\n",
       " ('低', 'O'),\n",
       " ('谷', 'O'),\n",
       " ('期', 'O'),\n",
       " ('认', 'O'),\n",
       " ('可', 'O'),\n",
       " ('我', 'O'),\n",
       " ('“', 'O')]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = \"小罗伯特·唐尼专门为漫威总裁凯文费奇颁发特别大奖，他在介绍凯文费奇的时候说道“我要感谢凯文，他在我的低谷期认可我“\"\n",
    "_X, _length = ner.sentence_encode(s)\n",
    "fd = {X: _X, length: _length, dropout: 1.}\n",
    "result = sess.run(logists, feed_dict=fd)\n",
    "result_entities = ner.entities_decode(result[0].argmax(axis=1))\n",
    "list(zip(s, result_entities))[:_length[0]][:len(s)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('马', 'B-LOC'),\n",
       " ('苏', 'I-LOC'),\n",
       " ('、', 'O'),\n",
       " ('周', 'B-PER'),\n",
       " ('冬', 'I-PER'),\n",
       " ('雨', 'I-PER'),\n",
       " ('、', 'O'),\n",
       " ('林', 'B-PER'),\n",
       " ('更', 'I-PER'),\n",
       " ('新', 'I-PER'),\n",
       " ('、', 'O'),\n",
       " ('霍', 'B-PER'),\n",
       " ('建', 'I-PER'),\n",
       " ('华', 'I-PER'),\n",
       " ('都', 'O'),\n",
       " ('被', 'O'),\n",
       " ('拍', 'O'),\n",
       " ('到', 'O'),\n",
       " ('出', 'O'),\n",
       " ('现', 'O'),\n",
       " ('在', 'O'),\n",
       " ('周', 'B-PER'),\n",
       " ('杰', 'I-PER'),\n",
       " ('伦', 'I-PER'),\n",
       " ('的', 'O'),\n",
       " ('演', 'O'),\n",
       " ('唱', 'O'),\n",
       " ('会', 'O'),\n",
       " ('上', 'O'),\n",
       " ('，', 'O'),\n",
       " ('此', 'O'),\n",
       " ('外', 'O'),\n",
       " ('，', 'O'),\n",
       " ('林', 'B-PER'),\n",
       " ('俊', 'I-PER'),\n",
       " ('杰', 'I-PER'),\n",
       " ('和', 'O'),\n",
       " ('陈', 'B-PER'),\n",
       " ('奕', 'I-PER'),\n",
       " ('迅', 'I-PER'),\n",
       " ('等', 'O'),\n",
       " ('超', 'O'),\n",
       " ('级', 'O'),\n",
       " ('巨', 'O'),\n",
       " ('星', 'O'),\n",
       " ('也', 'O'),\n",
       " ('都', 'O'),\n",
       " ('作', 'O'),\n",
       " ('为', 'O'),\n",
       " ('演', 'O'),\n",
       " ('唱', 'O'),\n",
       " ('会', 'O'),\n",
       " ('嘉', 'O'),\n",
       " ('宾', 'O'),\n",
       " ('出', 'O'),\n",
       " ('现', 'O'),\n",
       " ('在', 'O'),\n",
       " ('演', 'O'),\n",
       " ('唱', 'O'),\n",
       " ('会', 'O'),\n",
       " ('上', 'O')]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = \"马苏、周冬雨、林更新、霍建华都被拍到出现在周杰伦的演唱会上，此外，林俊杰和陈奕迅等超级巨星也都作为演唱会嘉宾出现在演唱会上\"\n",
    "_X, _length = ner.sentence_encode(s)\n",
    "fd = {X: _X, length: _length, dropout: 1.}\n",
    "result = sess.run(logists, feed_dict=fd)\n",
    "result_entities = ner.entities_decode(result[0].argmax(axis=1))\n",
    "list(zip(s, result_entities))[:_length[0]][:len(s)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
